import torch
from pbb.utils import runexp
import argparse
import numpy as np
import math 



# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(torch.cuda.is_available())

# ---------------------------------------------------------------------------
# Experiment settings. Every value below is forwarded to the runexp(...) call
# at the bottom of this script; their exact semantics live in pbb.utils.
# ---------------------------------------------------------------------------

# Optimisation schedule.
BATCH_SIZE = 128
TRAIN_EPOCHS = np.power(2, 17)  # 2**17 = 131072, kept as a NumPy integer
LEARNING_RATE = 1
MOMENTUM = 0
LEARNING_RATE_PRIOR = 0
MOMENTUM_PRIOR = 0

# Prior configuration (passed positionally as the prior type and its sigma).
PRIOR = 'rand'
SIGMAPRIOR = 0.03

# Passed as delta= / delta_test=.
# NOTE(review): presumably confidence parameters of the bound - confirm in pbb.
DELTA = 0.025
DELTA_TEST = 0.01

# Remaining knobs: PMIN is positional, the others go in as kl_penalty= and
# mc_samples=.
PMIN = 1e-5
KL_PENALTY = 0.01
MC_SAMPLES = 1000


# Command-line interface. The single option is the exponent of the network
# width: the script later computes Net_Width = 2 ** number_of_width.
parser = argparse.ArgumentParser(
    description='Run the pbb MNIST experiment for a chosen network width')
parser.add_argument('--number_of_width', type=int, default=5,
                    help='exponent of the network width: '
                         'Net_Width = 2 ** number_of_width (default: 5)')
args = parser.parse_args()

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Args:
        tensor: torch tensor that is overwritten in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower bound; the result is clamped to be >= ``a``.
        b: upper bound; the result is clamped to be <= ``b``.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Overwrite ``tensor`` with normal(mean, std) samples restricted to [a, b].

    Sampling uses the inverse-CDF method: uniform draws between the normal CDF
    values of the two bounds are pushed through ``erfinv``. All tensor
    operations are in place and run under ``torch.no_grad()``.
    """
    root_two = math.sqrt(2.)

    def std_normal_cdf(x):
        # CDF of the standard normal distribution, expressed through erf.
        return (1. + math.erf(x / root_two)) / 2.

    # CDF values of the standardised truncation bounds.
    cdf_low = std_normal_cdf((a - mean) / std)
    cdf_high = std_normal_cdf((b - mean) / std)

    with torch.no_grad():
        # Uniform draws in [cdf_low, cdf_high], then mapped to [2l - 1, 2u - 1],
        # which is the argument range erfinv expects.
        tensor.uniform_(cdf_low, cdf_high).mul_(2).sub_(1)

        # Keep values strictly inside (-1, 1): erfinv is infinite at +/-1.
        margin = 1. - torch.finfo(tensor.dtype).eps
        tensor.clamp_(min=-margin, max=margin)

        # Inverse CDF of the standard normal, then shift/scale to (mean, std).
        tensor.erfinv_().mul_(std * root_two).add_(mean)

        # Final clamp so no value ends up outside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor
    
# Seed torch and NumPy with the same value before the initial weights are drawn.
SEED = 1634032502
torch.manual_seed(SEED)
np.random.seed(SEED)

# The network width is a power of two whose exponent comes from the CLI.
Net_Width = np.power(2, args.number_of_width)

# Shapes of the four initial weight tensors, in draw order:
# 28*28 -> Net_Width -> Net_Width -> Net_Width -> 10.
# NOTE(review): presumably (out_features, in_features) per layer - confirm
# against how runexp consumes full_weight_bias_list.
layer_shapes = [
    (Net_Width, 28 * 28),
    (Net_Width, Net_Width),
    (Net_Width, Net_Width),
    (10, Net_Width),
]

# Each tensor is filled from a normal(0, 1) truncated to [-2, 2]. The
# comprehension walks the shapes in order, so the random draws happen in the
# sequence l1, l2, l3, l4.
full_weight_bias_list = [
    trunc_normal_(torch.Tensor(*shape), 0, 1, -2, 2) for shape in layer_shapes
]
(weights_mu_init_l1,
 weights_mu_init_l2,
 weights_mu_init_l3,
 weights_mu_init_l4) = full_weight_bias_list

# Run the experiment; runexp hands back five values.
results = runexp(
    Net_Width, 'mnist', 'fclassic', PRIOR, 'fcn', SIGMAPRIOR, PMIN,
    LEARNING_RATE, MOMENTUM, LEARNING_RATE_PRIOR, MOMENTUM_PRIOR,
    delta=DELTA,
    kl_penalty=KL_PENALTY,
    delta_test=DELTA_TEST,
    mc_samples=MC_SAMPLES,
    train_epochs=TRAIN_EPOCHS,
    device=DEVICE,
    perc_train=1.0,
    verbose=True,
    dropout_prob=0.2,
    batch_size=BATCH_SIZE,
    full_weight_bias_list=full_weight_bias_list,
)
v0, v1, v2, v3, v4 = results

# Persist the five results in one .npy file named after the width.
# NOTE(review): np.save converts the list to an array first; if the five
# values differ in shape that conversion may fail - confirm what runexp returns.
np.save(f"{Net_Width}_no_KL_gd.npy", [v0, v1, v2, v3, v4])
